package host.springboot.framework.jwt.util;

import com.auth0.jwt.JWT;
import com.auth0.jwt.JWTCreator;
import com.auth0.jwt.JWTVerifier;
import com.auth0.jwt.algorithms.Algorithm;
import com.auth0.jwt.exceptions.JWTDecodeException;
import com.auth0.jwt.exceptions.JWTVerificationException;
import com.auth0.jwt.exceptions.SignatureVerificationException;
import com.auth0.jwt.exceptions.TokenExpiredException;
import com.auth0.jwt.interfaces.DecodedJWT;
import host.springboot.framework.core.util.Assert;
import host.springboot.framework.jwt.payload.JwtPayload;
import org.jspecify.annotations.NonNull;
import org.jspecify.annotations.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.sql.Date;
import java.time.LocalDateTime;
import java.time.ZoneId;
import java.util.Objects;

/**
 * JWT工具类
 *
 * <p><a href="https://github.com/auth0/java-jwt">点击查看官方文档</a>
 *
 * @author JiYinchuan
 * @since 1.0.0
 */
public class JwtUtils {

    private static final Logger LOGGER = LoggerFactory.getLogger(JwtUtils.class);

    private static final String LOG_TAG = "KS-JWT-Utils";

    /**
     * 生成Token
     *
     * @param kid     keyId
     * @param secret  密钥, 默认采用哈希256加密
     * @param payload Jwt载荷
     * @return JwtToken
     * @since 1.0.0
     */
    public static @NonNull String createToken(
            @Nullable String kid,
            @NonNull String secret,
            @NonNull JwtPayload payload) {
        return createToken(kid, Algorithm.HMAC256(secret), payload);
    }

    /**
     * 生成Token
     *
     * @param kid       keyId
     * @param algorithm 密钥算法
     * @param payload   Jwt载荷
     * @return JwtToken
     * @since 1.0.0
     */
    public static @NonNull String createToken(
            @Nullable String kid,
            @NonNull Algorithm algorithm,
            @NonNull JwtPayload payload) {
        Assert.notNull(algorithm, "JWT algorithm must not be null");
        Assert.notNull(payload, "JWT payload must not be null");

        JWTCreator.Builder builder = JWT.create();
        if (Objects.nonNull(kid)) {
            builder.withKeyId(kid);
        }
        return builder.withIssuer(payload.getIss()).withSubject(payload.getSub())
                .withExpiresAt(Date.from(payload.getExp().atZone(ZoneId.systemDefault()).toInstant()))
                .withNotBefore(Date.from(payload.getNbf().atZone(ZoneId.systemDefault()).toInstant()))
                .withIssuedAt(Date.from(payload.getIat().atZone(ZoneId.systemDefault()).toInstant()))
                .withJWTId(payload.getJti())
                .withAudience(payload.getAud().toArray(new String[0]))
                .withClaim(JwtPayload.USER_ID, payload.getUserId())
                .withClaim(JwtPayload.INTERNAL, payload.getInternal())
                .sign(algorithm);
    }

    /**
     * 校验Token
     *
     * @param secret 密钥
     * @param token  Token
     * @return [true: 验证成功, false: 验证失败]
     * @since 1.0.0
     */
    public static boolean verifyToken(@NonNull String secret, @NonNull String token) {
        Assert.notNull(secret, "JWT secret must not be null");
        Assert.notNull(token, "JWT token must not be null");

        try {
            decodeTokenGet(secret, token);
            return true;
        } catch (SignatureVerificationException signatureVerificationException) {
            LOGGER.debug("[{}] Token签名异常 [Secret: {}, Token: {}]", LOG_TAG, secret, token);
            return false;
        } catch (TokenExpiredException tokenExpiredException) {
            LOGGER.debug("[{}] Token已过期 [Secret: {}, Token: {}]", LOG_TAG, secret, token);
            return false;
        }
    }

    /**
     * 解析Token
     *
     * @param secret 密钥
     * @param token  Token
     * @return JwtPayload
     * @throws SignatureVerificationException 签名验证异常
     * @throws TokenExpiredException          Token过期异常
     * @throws JWTDecodeException             Token解析失败
     * @since 1.0.0
     */
    public static @NonNull JwtPayload decodeTokenGet(@NonNull String secret, @NonNull String token)
            throws SignatureVerificationException, TokenExpiredException {
        Assert.notNull(secret, "JWT secret must not be null");
        Assert.notNull(token, "JWT token must not be null");

        Algorithm algorithm = Algorithm.HMAC256(secret);
        JWTVerifier verifier = JWT.require(algorithm).build();
        DecodedJWT jwt = verifier.verify(token);
        return new JwtPayload()
                .setIss(jwt.getIssuer())
                .setSub(jwt.getSubject())
                .setExp(jwt.getExpiresAt().toInstant().atZone(ZoneId.systemDefault()).toLocalDateTime())
                .setNbf(jwt.getNotBefore().toInstant().atZone(ZoneId.systemDefault()).toLocalDateTime())
                .setIat(jwt.getIssuedAt().toInstant().atZone(ZoneId.systemDefault()).toLocalDateTime())
                .setJti(jwt.getId())
                .setAud(jwt.getAudience())
                .setUserId(jwt.getClaim(JwtPayload.USER_ID).asLong())
                .setInternal(jwt.getClaim(JwtPayload.INTERNAL).asBoolean());
    }
}
